Skip to content

mamba: shift silu(z) gate from RMSNormGated into selective_state_update#4461

Open
wdykas wants to merge 3 commits intoNVIDIA:mainfrom
wdykas:fix/mamba-gate-shift-silu-into-ssm
Open

mamba: shift silu(z) gate from RMSNormGated into selective_state_update#4461
wdykas wants to merge 3 commits intoNVIDIA:mainfrom
wdykas:fix/mamba-gate-shift-silu-into-ssm

Conversation

@wdykas
Copy link
Copy Markdown
Contributor

@wdykas wdykas commented Apr 24, 2026

In the decode path with rmsnorm=True, the previous code passed z=None to selective_state_update and then applied the gate and normalization together in the gated RMSNorm kernel:

y = selective_state_update(..., z=None, ...)       # SSM writes y
y = self.norm(y, z)                                # silu(z)*y, rmsnorm, weight

selective_state_update already supports applying silu(z)*y inline via its HAS_Z path -- the work then happens in fp32 registers right after the state-C reduction, before y's bf16 round-trip out to HBM. Swapping the gate site lets the downstream RMSNormGated take its z=None fast path (skipping the gate step entirely):

y = selective_state_update(..., z=z_reshaped, ...) # SSM writes silu(z)*y
y = self.norm(y, None)                             # rmsnorm, weight

Net effect:

  • One fewer HBM round-trip of z (SSM reads z in-kernel instead of the post-SSM gated-norm reading y and z separately).
  • Cheaper per-call cost for the gated norm (no gate work).
  • Math is identical up to bf16 rounding; in fact slightly more precise because y no longer round-trips through bf16 between SSM and norm.

Measured on nano-v3 at BS=1, OSL=256, 10 iterations, outlier-trimmed p50:

  • gate_shift_off: 255.0 tok/s p50
  • gate_shift_on: 257.5 tok/s p50 (+1.0%)

What does this PR do ?

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Issue tracking

For PRs from open-source community contributors:

  • New features: a linked issue is required. Please open a feature request and reference it here before submitting the PR.
  • Small updates (bug fixes, minor improvements): a linked issue is recommended and will accelerate the PR review process.

Linked issue:

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

In the decode path with rmsnorm=True, the previous code passed z=None to
selective_state_update and then applied the gate and normalization together
in the gated RMSNorm kernel:

    y = selective_state_update(..., z=None, ...)       # SSM writes y
    y = self.norm(y, z)                                # silu(z)*y, rmsnorm, weight

selective_state_update already supports applying silu(z)*y inline via its
HAS_Z path -- the work then happens in fp32 registers right after the
state-C reduction, before y's bf16 round-trip out to HBM. Swapping the
gate site lets the downstream RMSNormGated take its z=None fast path
(skipping the gate step entirely):

    y = selective_state_update(..., z=z_reshaped, ...) # SSM writes silu(z)*y
    y = self.norm(y, None)                             # rmsnorm, weight

Net effect:
  - One fewer HBM round-trip of z (SSM reads z in-kernel instead of the
    post-SSM gated-norm reading y and z separately).
  - Cheaper per-call cost for the gated norm (no gate work).
  - Math is identical up to bf16 rounding; in fact slightly more precise
    because y no longer round-trips through bf16 between SSM and norm.

Measured on nano-v3 at BS=1, OSL=256, 10 iterations, outlier-trimmed p50:
  - gate_shift_off: 255.0 tok/s p50
  - gate_shift_on:  257.5 tok/s p50 (+1.0%)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@wdykas wdykas requested review from a team as code owners April 24, 2026 17:02
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Apr 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@svcnvidia-nemo-ci svcnvidia-nemo-ci marked this pull request as draft April 24, 2026 17:02
@github-actions
Copy link
Copy Markdown
Contributor

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.

@wdykas
Copy link
Copy Markdown
Contributor Author

wdykas commented Apr 24, 2026

This will need to be reviewed and tested much more.

y = self.norm(y, z)
# Gate was already applied inside ``selective_state_update`` via
# HAS_Z, so pass z=None to the gated norm's no-gate fast path.
y = self.norm(y, None)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be too crazy to subsume the norm in ssm update as well?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can try that this week as well

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ill keep trying but its slower for some reason every time I try. But that could just be a skill issue

@wdykas wdykas marked this pull request as ready for review April 28, 2026 13:26
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team April 28, 2026 13:26
@wdykas
Copy link
Copy Markdown
Contributor Author

wdykas commented Apr 28, 2026

/ok to test e44812f

@wdykas
Copy link
Copy Markdown
Contributor Author

wdykas commented Apr 28, 2026

/ok to test debb995

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants